/*
 * Copyright (C) 2016 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio

import app.cash.burst.InterceptTest
import java.io.IOException
import java.io.InterruptedIOException
import java.util.Random
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.time.Duration.Companion.milliseconds
import okio.ByteString.Companion.decodeHex
import okio.HashingSink.Companion.sha1
import okio.TestUtil.assumeNotWindows
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Test

/**
 * Tests for [Pipe]: basic transfer, producer/consumer exchange through a bounded buffer, sink and
 * source timeouts, and how closing one end affects operations on the other.
 */
class PipeTest {
  /**
   * Background threads for the concurrent tests. Two threads are required because [largeDataset]
   * runs a producer and a consumer that must make progress at the same time through a small pipe.
   * The interceptor presumably manages this executor's lifecycle per test — confirm in TestExecutor.
   */
  @InterceptTest
  private val executorService = TestExecutor(2)

  /** Bytes written to the sink are readable from the source; closing the sink yields EOF (-1). */
  @Test
  fun test() {
    val pipe = Pipe(6)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3L)
    val source = pipe.source
    val readBuffer = Buffer()
    assertEquals(3L, source.read(readBuffer, 6L))
    assertEquals("abc", readBuffer.readUtf8())
    pipe.sink.close()
    assertEquals(-1L, source.read(readBuffer, 6L))
    source.close()
  }

  /**
   * A producer writes the first 16 MiB of bytes generated by `Random(0)` to a sink, and a
   * consumer consumes them. Both compute hashes of their data to confirm that they're as expected.
   */
  @Test
  fun largeDataset() {
    val pipe = Pipe(1000L) // An awkward size to force producer/consumer exchange.
    val totalBytes = 16L * 1024L * 1024L
    val expectedHash = "7c3b224bea749086babe079360cf29f98d88262d".decodeHex()

    // Write data to the sink.
    val sinkHash = executorService.submit<ByteString> {
      val hashingSink = sha1(pipe.sink)
      val random = Random(0)
      val data = ByteArray(8192)
      val buffer = Buffer()
      var i = 0L
      while (i < totalBytes) {
        random.nextBytes(data)
        buffer.write(data)
        hashingSink.write(buffer, buffer.size)
        i += data.size.toLong()
      }
      hashingSink.close()
      hashingSink.hash
    }

    // Read data from the source.
    val sourceHash = executorService.submit<ByteString> {
      val blackhole = Buffer()
      val hashingSink = sha1(blackhole)
      val buffer = Buffer()
      while (pipe.source.read(buffer, Long.MAX_VALUE) != -1L) {
        hashingSink.write(buffer, buffer.size)
        // Only the hash matters; discard the bytes so memory stays bounded.
        blackhole.clear()
      }
      pipe.source.close()
      hashingSink.hash
    }
    assertEquals(expectedHash, sinkHash.get())
    assertEquals(expectedHash, sourceHash.get())
  }

  /** A write that cannot fit in a full pipe fails with "timeout" after the sink's timeout. */
  @Test
  fun sinkTimeout() {
    assumeNotWindows()
    val pipe = Pipe(3)
    pipe.sink.timeout().timeout(1000, TimeUnit.MILLISECONDS)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3L)
    val start = now()
    try {
      pipe.sink.write(Buffer().writeUtf8("def"), 3L)
      fail()
    } catch (expected: InterruptedIOException) {
      assertEquals("timeout", expected.message)
    }
    assertElapsed(1000.0, start)
    val readBuffer = Buffer()
    assertEquals(3L, pipe.source.read(readBuffer, 6L))
    assertEquals("abc", readBuffer.readUtf8())
  }

  /** A read from an empty pipe fails with "timeout" after the source's timeout, reading nothing. */
  @Test
  fun sourceTimeout() {
    assumeNotWindows()
    val pipe = Pipe(3L)
    pipe.source.timeout().timeout(1000, TimeUnit.MILLISECONDS)
    val start = now()
    val readBuffer = Buffer()
    try {
      pipe.source.read(readBuffer, 6L)
      fail()
    } catch (expected: InterruptedIOException) {
      assertEquals("timeout", expected.message)
    }
    assertElapsed(1000.0, start)
    assertEquals(0, readBuffer.size)
  }

  /**
   * The writer is writing 12 bytes as fast as it can to a 3 byte buffer. The reader alternates
   * sleeping 1000 ms, then reading 3 bytes. That should make for an approximate timeline like
   * this:
   *
   * ```
   *    0: writer writes 'abc', blocks
   *    0: reader sleeps until 1000
   * 1000: reader reads 'abc', sleeps until 2000
   * 1000: writer writes 'def', blocks
   * 2000: reader reads 'def', sleeps until 3000
   * 2000: writer writes 'ghi', blocks
   * 3000: reader reads 'ghi', sleeps until 4000
   * 3000: writer writes 'jkl', returns
   * 4000: reader reads 'jkl', returns
   * ```
   *
   * Because the writer is writing to a buffer, it finishes before the reader does.
   */
  @Test
  fun sinkBlocksOnSlowReader() {
    val pipe = Pipe(3L)
    val position = AtomicInteger()

    val reader = executorService.submit {
      val buffer = Buffer()
      Thread.sleep(1000L)
      position.set(1)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("abc", buffer.readUtf8())
      Thread.sleep(1000L)
      position.set(2)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("def", buffer.readUtf8())
      Thread.sleep(1000L)
      position.set(3)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("ghi", buffer.readUtf8())
      Thread.sleep(1000L)
      position.set(4)
      assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE))
      assertEquals("jkl", buffer.readUtf8())
    }

    pipe.sink.write(Buffer().writeUtf8("abcdefghijkl"), 12)
    // The write returns once 'jkl' fits, which happens after the reader's third read.
    assertEquals(3, position.get())

    // Propagate any assertion failure from the reader thread, and wait for it to drain 'jkl';
    // otherwise a failing reader would go unnoticed.
    reader.get()
  }

  /** A writer blocked on a full pipe fails with "source is closed" once the source is closed. */
  @Test
  fun sinkWriteFailsByClosedReader() {
    val pipe = Pipe(3L)
    executorService.schedule(1000.milliseconds) {
      pipe.source.close()
    }
    val start = now()
    try {
      pipe.sink.write(Buffer().writeUtf8("abcdef"), 6)
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
      assertElapsed(1000.0, start)
    }
  }

  @Test
  fun sinkFlushDoesntWaitForReader() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.sink.flush()
    val bufferedSource = pipe.source.buffer()
    assertEquals("abc", bufferedSource.readUtf8(3))
  }

  @Test
  fun sinkFlushFailsIfReaderIsClosedBeforeAllDataIsRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.source.close()
    try {
      pipe.sink.flush()
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
    }
  }

  @Test
  fun sinkCloseFailsIfReaderIsClosedBeforeAllDataIsRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.source.close()
    try {
      pipe.sink.close()
      fail()
    } catch (expected: IOException) {
      assertEquals("source is closed", expected.message)
    }
  }

  /** After the sink is closed, both write and flush throw IllegalStateException("closed"). */
  @Test
  fun sinkClose() {
    val pipe = Pipe(100L)
    pipe.sink.close()
    try {
      pipe.sink.write(Buffer().writeUtf8("abc"), 3)
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
    try {
      pipe.sink.flush()
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
  }

  @Test
  fun sinkMultipleClose() {
    val pipe = Pipe(100L)
    pipe.sink.close()
    pipe.sink.close()
  }

  @Test
  fun sinkCloseDoesntWaitForSourceRead() {
    val pipe = Pipe(100L)
    pipe.sink.write(Buffer().writeUtf8("abc"), 3)
    pipe.sink.close()
    val bufferedSource = pipe.source.buffer()
    assertEquals("abc", bufferedSource.readUtf8())
    assertTrue(bufferedSource.exhausted())
  }

  /** After the source is closed, read throws IllegalStateException("closed"). */
  @Test
  fun sourceClose() {
    val pipe = Pipe(100L)
    pipe.source.close()
    try {
      pipe.source.read(Buffer(), 3)
      fail()
    } catch (expected: IllegalStateException) {
      assertEquals("closed", expected.message)
    }
  }

  @Test
  fun sourceMultipleClose() {
    val pipe = Pipe(100L)
    pipe.source.close()
    pipe.source.close()
  }

  /** A reader blocked on an empty pipe returns -1 once the sink is closed. */
  @Test
  fun sourceReadUnblockedByClosedSink() {
    val pipe = Pipe(3L)
    executorService.schedule(1000.milliseconds) {
      pipe.sink.close()
    }
    val start = now()
    val readBuffer = Buffer()
    assertEquals(-1, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals(0, readBuffer.size)
    assertElapsed(1000.0, start)
  }

  /**
   * The writer has 12 bytes to write. It alternates sleeping 1000 ms, then writing 3 bytes. The
   * reader is reading as fast as it can. That should make for an approximate timeline like this:
   *
   * ```
   *    0: writer sleeps until 1000
   *    0: reader blocks
   * 1000: writer writes 'abc', sleeps until 2000
   * 1000: reader reads 'abc'
   * 2000: writer writes 'def', sleeps until 3000
   * 2000: reader reads 'def'
   * 3000: writer writes 'ghi', sleeps until 4000
   * 3000: reader reads 'ghi'
   * 4000: writer writes 'jkl', returns
   * 4000: reader reads 'jkl', returns
   * ```
   */
  @Test
  fun sourceBlocksOnSlowWriter() {
    val pipe = Pipe(100L)
    val position = AtomicInteger()

    val writer = executorService.submit {
      Thread.sleep(1000L)
      position.set(1)
      pipe.sink.write(Buffer().writeUtf8("abc"), 3)
      Thread.sleep(1000L)
      position.set(2)
      pipe.sink.write(Buffer().writeUtf8("def"), 3)
      Thread.sleep(1000L)
      position.set(3)
      pipe.sink.write(Buffer().writeUtf8("ghi"), 3)
      Thread.sleep(1000L)
      position.set(4)
      pipe.sink.write(Buffer().writeUtf8("jkl"), 3)
    }

    val readBuffer = Buffer()
    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("abc", readBuffer.readUtf8())
    assertEquals(1, position.get())

    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("def", readBuffer.readUtf8())
    assertEquals(2, position.get())

    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("ghi", readBuffer.readUtf8())
    assertEquals(3, position.get())

    assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE))
    assertEquals("jkl", readBuffer.readUtf8())
    assertEquals(4, position.get())

    // Propagate any exception thrown on the writer thread.
    writer.get()
  }

  /** Returns the nanotime in milliseconds as a double for measuring timeouts. */
  private fun now(): Double = System.nanoTime() / 1000000.0

  /**
   * Fails the test unless the time from start until now is duration, accepting differences in
   * -50..+450 milliseconds.
   *
   * The window is asymmetric because waits tend to overshoot rather than undershoot. The measured
   * time is shifted down by 200 ms, and a ±250 ms delta is then applied around [duration].
   */
  private fun assertElapsed(duration: Double, start: Double) {
    assertEquals(duration, now() - start - 200.0, 250.0)
  }
}
